""" Import libraries """
import argparse
import os

import time
import copy
import numpy as np
from tensorforce.agents import Agent
from tensorforce.environments import Environment
import matplotlib.pyplot as plt
import tensorflow as tf

from model.env_rl_helper import RLModelHelper
from utils.init_net_helper import ini_net, save_missing_arguments

""" Parser related """
# Command-line interface: two integer options, each defaulting to 0.
# NOTE(review): --respath is parsed but never read in the visible script.
parser = argparse.ArgumentParser(description="run_file")
for option_name in ('--seed', '--respath'):
    parser.add_argument(option_name, default=0, type=int)
args = parser.parse_args()

# Seed NumPy and TensorFlow with the same user-supplied value.
run_seed = args.seed
np.random.seed(run_seed)
tf.random.set_seed(run_seed)

"""
The experiment
"""
""" Create the environment and the RL agent """
# Wrap the project environment; episodes are capped at 200 timesteps.
env = Environment.create(environment=RLModelHelper, max_episode_timesteps=200)

# Toggle between a user-supplied Keras network and the agent's default network.
customize_network = False

# Settings shared by both variants of the PPO agent.
agent_kwargs = dict(agent='ppo', environment=env, batch_size=10, learning_rate=0.0001)
if customize_network:
    print('Keras network is used in agent')
    network = ini_net('None')
    agent_kwargs['network'] = dict(type='keras', model=network)
else:
    print('Auto network is used in agent')
agent = Agent.create(**agent_kwargs)

""" Initialize simulation """
# One entry is appended to each history list per training episode.
score_history, score_history_true = [], []  # summed rewards: raw / with violation penalty
actions_history, states_history = [], []  # per-episode 'Tc' actions and visited states
bad_state_transitions_history = []  # per-episode list of constraint-violating transitions

n_episodes = 10000  # number of training episodes
num_updates = 0  # running sum of the values returned by agent.observe()

""" Train the RL agent """
# Start from the negated wall-clock time; adding the end time after the loop
# leaves the elapsed number of seconds in time_train.
time_train = -time.time()
for episode_idx in range(n_episodes):
    env.actual_reset = True
    states = env.reset()  # Resample the initial state
    terminal = False
    actions_episode, states_episode = [], [states]
    bad_state_transitions_episode = []
    score = score_true = 0
    reset_x = 0  # how many times the state was rolled back in this episode

    while not terminal:
        previous_state = copy.deepcopy(states)
        actions = agent.act(states=states)  # Actor
        states, terminal, reward = env.execute(actions=actions)  # Advance one time-step
        # observe() must follow act(); its return value is accumulated as the update count.
        num_updates += agent.observe(terminal=terminal, reward=reward)
        score += reward
        score_true += reward

        applied_tc = actions['Tc']
        actions_episode.append(applied_tc)
        states_episode.append(states)

        # Analysis: find which state transitions give a bad reward. A transition
        # is flagged when any row of hrepABig @ states exceeds hrepBBig + tol
        # (per the original author this also covers physical-constraint violations).
        constraint_lhs = np.matmul(env.hrepABig, states)
        violated = (constraint_lhs > env.hrepBBig + env.tol).any()
        if not violated:
            continue

        bad_state_transitions_episode.append([env.timestep, previous_state, applied_tc, reward, states])
        # In the "true" score the step reward is replaced by a fixed -1000 penalty.
        score_true -= reward
        score_true += -1000
        # Roll both the loop's state and the environment's state back to the
        # pre-transition state. (An earlier bound check on env.xlb / env.xub
        # used to guard this rollback; it now runs for every flagged transition.)
        states = copy.deepcopy(previous_state)
        env.current_state = copy.deepcopy(previous_state)
        reset_x += 1

    print(f"Number of steps in this episode is {env.timestep}, Reset state {reset_x} times")
    score_history.append(score)
    score_history_true.append(score_true)
    bad_state_transitions_history.append(bad_state_transitions_episode)
    # Every episode's trajectory is stored (the original author noted memory
    # limits and once considered keeping only every 100th episode).
    actions_history.append(actions_episode)
    states_history.append(states_episode)
    recent_mean = np.mean(score_history[-100:])
    print(f"Episode number {episode_idx + 1} and the score is {score} and the average score is {recent_mean}")

time_train += time.time()
print('Training time:', time_train/60, 'min')

""" Save the agent and the reward trajectory """
def _save_history(path, history):
    """Save ``history`` to ``path`` with ``np.save``, tolerating ragged data.

    The history is first converted the same way ``np.save`` would convert it.
    If NumPy refuses because the nested lists have inhomogeneous shapes
    (a ``ValueError`` on recent NumPy releases), the top-level entries are
    stored in a 1-D object array instead, so the results collected during
    training are still written to disk.
    """
    try:
        array = np.asarray(history)
    except ValueError:
        array = np.empty(len(history), dtype=object)
        # Element-wise assignment keeps each entry as one opaque object and
        # avoids any attempt to broadcast the nested sequences.
        for index, item in enumerate(history):
            array[index] = item
    np.save(path, array)


# Base name shared by every artefact of this run.
fname = f"{agent.spec['agent']}_cstr_0p0001lr_{n_episodes}episodes"
respth = 'results/seed_{}_train/'
out_dir = respth.format(args.seed)
# Create the directory together with any missing parents; an already existing
# directory is fine, while any other failure is raised here rather than
# surfacing later as a failed save.
os.makedirs(out_dir, exist_ok=True)

figure_file = os.path.join(out_dir, fname)
figure_avg_file = os.path.join(out_dir, fname + '_avg')
scores_file = os.path.join(out_dir, 'scores_' + fname)
scores_true_file = os.path.join(out_dir, 'scores_true_' + fname)
actions_file = os.path.join(out_dir, 'actions_' + fname)
states_file = os.path.join(out_dir, 'states_' + fname)
bad_states_file = os.path.join(out_dir, 'bad_states_' + fname)

_save_history(scores_file, score_history)
_save_history(scores_true_file, score_history_true)
# Episodes may differ in length, so these histories can be ragged.
_save_history(actions_file, actions_history)
_save_history(states_file, states_history)
_save_history(bad_states_file, bad_state_transitions_history)

agent.save(out_dir, filename=fname)
if customize_network:
    save_missing_arguments(agent, out_dir)

# Produce the same pair of plots for the raw scores and for the
# penalty-adjusted ("true") scores: the per-episode curve is saved via
# matplotlib, then the environment's plot_learning_curve helper is called
# with 1-based episode numbers. Output names differ only by the suffix.
for history, suffix in ((score_history, ''), (score_history_true, 'true')):
    plt.figure()
    plt.plot(history)
    plt.ylabel('Score')
    plt.xlabel('Episode')
    plt.savefig(figure_file + suffix)
    # plt.show() is intentionally not called; figures are only written to disk.

    x = list(range(1, n_episodes + 1))
    env.utils.plot_learning_curve(history, x, figure_avg_file + suffix)

print('Done')